Twin Delayed DDPG (TD3) — from scratch in PyTorch#

TD3 (Fujimoto, van Hoof, Meger, 2018) is an off-policy, deterministic actor-critic algorithm for continuous control. It improves on DDPG with three small but crucial modifications:

  1. Twin critics: learn two Q-functions and use the minimum in the bootstrap target.

  2. Target policy smoothing: add clipped noise to the target action when computing the target Q.

  3. Delayed policy updates: update the actor (and target networks) less often than the critics.

In this notebook we:

  • write the TD3 update equations precisely (LaTeX)

  • implement TD3 at a low level in PyTorch (no RL libraries)

  • train on a Gymnasium environment and plot episodic returns (Plotly)


Learning goals#

  • Understand why DDPG overestimates and how TD3 fixes it

  • Implement replay buffer + target networks + twin critics + delayed updates

  • Train a working agent and visualize learning curves

Prerequisites#

  • Basic PyTorch (modules, optimizers, autograd)

  • Q-learning / bootstrapping and the Bellman equation

  • Actor-critic idea (policy + value function)

  • Continuous action spaces (e.g., Pendulum)

import copy
import os
import time

import numpy as np
import pandas as pd
import plotly.graph_objects as go
import plotly.io as pio

try:
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    TORCH_AVAILABLE = True
except Exception as e:
    TORCH_AVAILABLE = False
    _TORCH_IMPORT_ERROR = e

try:
    import gymnasium as gym

    GYM_AVAILABLE = True
except Exception as e:
    GYM_AVAILABLE = False
    _GYM_IMPORT_ERROR = e


pio.templates.default = 'plotly_white'
pio.renderers.default = os.environ.get("PLOTLY_RENDERER", "notebook")
np.set_printoptions(precision=4, suppress=True)

SEED = 42
rng = np.random.default_rng(SEED)
# --- Run configuration ---
FAST_RUN = True

ENV_ID = 'Pendulum-v1'

TOTAL_TIMESTEPS = 10_000 if FAST_RUN else 200_000
START_STEPS = 1_000 if FAST_RUN else 10_000
UPDATE_AFTER = 1_000 if FAST_RUN else 10_000

BATCH_SIZE = 256
BUFFER_SIZE = 200_000

# TD3 hyperparameters
GAMMA = 0.99
TAU = 0.005

ACTOR_LR = 1e-3
CRITIC_LR = 1e-3

POLICY_DELAY = 2
TARGET_POLICY_NOISE = 0.2
TARGET_NOISE_CLIP = 0.5

EXPLORATION_NOISE = 0.1

HIDDEN_SIZES = (256, 256)

1) TD3: the exact updates (twin critics + delayed actor)#

We use:

  • deterministic policy (actor) \(a = \pi_\phi(s)\)

  • two critics \(Q_{\theta_1}(s,a)\) and \(Q_{\theta_2}(s,a)\)

  • target networks \((\phi', \theta_1', \theta_2')\) updated by Polyak averaging

Given a transition \((s,a,r,s',\text{terminal})\) sampled from the replay buffer, TD3 modifies the standard DDPG update in three ways: the first two change how the target is built, the third changes the update schedule.

1. Target policy smoothing#

TD3 does not evaluate the target critics at the raw target action \(\pi_{\phi'}(s')\). Instead it adds clipped Gaussian noise:

\[ \tilde a = \pi_{\phi'}(s') + \epsilon,\qquad \epsilon \sim \mathrm{clip}(\mathcal N(0, \sigma^2),\,-c,\,+c) \]
\[ \tilde a \leftarrow \mathrm{clip}(\tilde a, a_{\min}, a_{\max}) \]

Intuition: this makes the target Q-value less sensitive to small action errors and prevents the critic from exploiting sharp, unrealistic peaks in \(Q\).
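In isolation, the smoothing step can be sketched as follows (the batch, \(\sigma\), \(c\), and action bounds are illustrative, matching the Pendulum-style range used later):

```python
import torch

torch.manual_seed(0)

sigma, c = 0.2, 0.5          # target noise scale and clip (illustrative)
a_low, a_high = -2.0, 2.0    # action bounds (illustrative)

next_action = torch.tensor([[1.9], [-0.3]])  # pi_phi'(s') for a batch of 2

# Clipped Gaussian noise, then clip the perturbed action into the valid range
noise = (torch.randn_like(next_action) * sigma).clamp(-c, c)
smoothed = (next_action + noise).clamp(a_low, a_high)

print(smoothed.shape)  # torch.Size([2, 1])
```

Note that two clips happen: the noise itself is clipped to \([-c, c]\), and the resulting action is clipped back into the action range.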

2. Twin critics (min target)#

Compute both target Q-values and take the minimum:

\[ y = r + \gamma (1-\text{terminal})\,\min\Big( Q_{\theta_1'}(s', \tilde a),\; Q_{\theta_2'}(s', \tilde a)\Big) \]

Each critic minimizes an MSE to this same target:

\[ L(\theta_i) = \mathbb E\big[(Q_{\theta_i}(s,a)-y)^2\big],\qquad i\in\{1,2\} \]

Taking the minimum is a simple bias-reduction trick: it turns DDPG’s optimistic target into a more conservative estimate, reducing overestimation error.
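A toy numeric check of the bias direction, assuming two independent, equally noisy estimates of the same true value: taking the max over estimates (what an optimistic target implicitly does) is biased high, while taking the min is biased low, i.e. conservative.

```python
import numpy as np

rng = np.random.default_rng(0)
true_q = 10.0

# Two independent noisy "critics" estimating the same true value
q1 = true_q + rng.normal(0.0, 1.0, size=100_000)
q2 = true_q + rng.normal(0.0, 1.0, size=100_000)

max_est = np.maximum(q1, q2).mean()  # optimistic: biased above true_q
min_est = np.minimum(q1, q2).mean()  # TD3's choice: biased below true_q

print(f"E[max] = {max_est:.3f}, E[min] = {min_est:.3f}, true = {true_q}")
```

The underestimation introduced by the min is considered benign because the policy will simply avoid those regions, whereas overestimation is actively exploited by the actor.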

3. Delayed policy updates#

The critics are updated every gradient step. The actor is updated only every \(d\) critic updates (e.g. \(d=2\)):

\[ \max_\phi\; J(\phi) = \mathbb E\big[ Q_{\theta_1}(s, \pi_\phi(s)) \big] \]

In code we minimize the negative:

\[ L_\pi(\phi) = -\mathbb E\big[ Q_{\theta_1}(s, \pi_\phi(s)) \big] \]

When (and only when) we update the actor, we also update all target networks with Polyak averaging:

\[ \theta_i' \leftarrow \tau \theta_i + (1-\tau)\theta_i',\qquad \phi' \leftarrow \tau \phi + (1-\tau)\phi' \]

Delaying the actor update lets the critics move closer to their fixed point, so the actor sees a less noisy / less biased gradient.
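As a quick sanity check on the Polyak update above: with a fixed online parameter, repeated soft updates move the target geometrically toward it, and with \(\tau = 0.005\) the target closes most of the gap within about 1000 updates.

```python
tau = 0.005
theta = 1.0        # online parameter (held fixed for illustration)
theta_targ = 0.0   # target parameter, initially far away

# Each soft update moves the target a fraction tau toward the online value
for _ in range(1000):
    theta_targ = tau * theta + (1 - tau) * theta_targ

print(f"target after 1000 soft updates: {theta_targ:.4f}")
```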

2) Implementation roadmap#

We will implement TD3 as a small set of building blocks:

  1. Gymnasium environment helpers (reset/step API differences)

  2. Replay buffer (NumPy storage, PyTorch sampling)

  3. Actor network \(\pi_\phi(s)\)

  4. Twin critic networks \(Q_{\theta_1}(s,a), Q_{\theta_2}(s,a)\)

  5. TD3 update step (critic update every step, actor + target update every POLICY_DELAY steps)

  6. Training loop + Plotly learning curve

def set_global_seeds(seed: int) -> None:
    np.random.seed(seed)
    if TORCH_AVAILABLE:
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


def env_reset(env, seed=None):
    out = env.reset(seed=seed) if seed is not None else env.reset()
    if isinstance(out, tuple) and len(out) == 2:
        obs, info = out
        return obs, info
    return out, {}


def env_step(env, action):
    out = env.step(action)
    if isinstance(out, tuple) and len(out) == 5:
        next_obs, reward, terminated, truncated, info = out
        done = bool(terminated or truncated)
        terminal = bool(terminated)  # time-limit truncation is not a terminal state
        return next_obs, float(reward), done, terminal, info
    if isinstance(out, tuple) and len(out) == 4:
        next_obs, reward, done, info = out
        terminal = bool(done)
        return next_obs, float(reward), bool(done), terminal, info
    raise ValueError('Unexpected env.step(...) output format')
if not TORCH_AVAILABLE:
    raise RuntimeError(f'PyTorch import failed: {_TORCH_IMPORT_ERROR}')

if not GYM_AVAILABLE:
    raise RuntimeError(f'Gymnasium import failed: {_GYM_IMPORT_ERROR}')


set_global_seeds(SEED)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

env = gym.make(ENV_ID)
obs, _ = env_reset(env, seed=SEED)

obs_dim = int(np.prod(env.observation_space.shape))
act_dim = int(np.prod(env.action_space.shape))

action_low = env.action_space.low.astype(np.float32)
action_high = env.action_space.high.astype(np.float32)

print('env:', ENV_ID)
print('obs_dim:', obs_dim)
print('act_dim:', act_dim)
print('action_low:', action_low)
print('action_high:', action_high)
print('device:', device)
env: Pendulum-v1
obs_dim: 3
act_dim: 1
action_low: [-2.]
action_high: [2.]
device: cpu
class ReplayBuffer:
    def __init__(self, obs_dim: int, act_dim: int, size: int, seed: int, device: torch.device):
        self.obs_buf = np.zeros((size, obs_dim), dtype=np.float32)
        self.next_obs_buf = np.zeros((size, obs_dim), dtype=np.float32)
        self.act_buf = np.zeros((size, act_dim), dtype=np.float32)
        self.rew_buf = np.zeros((size, 1), dtype=np.float32)
        self.done_buf = np.zeros((size, 1), dtype=np.float32)

        self.max_size = int(size)
        self.ptr = 0
        self.size = 0

        self.rng = np.random.default_rng(seed)
        self.device = device

    def add(self, obs, act, rew: float, next_obs, terminal: bool) -> None:
        self.obs_buf[self.ptr] = np.asarray(obs, dtype=np.float32).reshape(-1)
        self.next_obs_buf[self.ptr] = np.asarray(next_obs, dtype=np.float32).reshape(-1)
        self.act_buf[self.ptr] = np.asarray(act, dtype=np.float32).reshape(-1)
        self.rew_buf[self.ptr] = float(rew)
        self.done_buf[self.ptr] = 1.0 if terminal else 0.0

        self.ptr = (self.ptr + 1) % self.max_size
        self.size = min(self.size + 1, self.max_size)

    def sample(self, batch_size: int):
        if self.size < batch_size:
            raise ValueError(f'Not enough samples: size={self.size}, batch_size={batch_size}')
        idxs = self.rng.integers(0, self.size, size=batch_size)

        obs = torch.as_tensor(self.obs_buf[idxs], device=self.device)
        act = torch.as_tensor(self.act_buf[idxs], device=self.device)
        rew = torch.as_tensor(self.rew_buf[idxs], device=self.device)
        next_obs = torch.as_tensor(self.next_obs_buf[idxs], device=self.device)
        done = torch.as_tensor(self.done_buf[idxs], device=self.device)

        return obs, act, rew, next_obs, done
def mlp(layer_sizes, activation=nn.ReLU, output_activation=nn.Identity):
    layers = []
    for i in range(len(layer_sizes) - 1):
        act = activation if i < len(layer_sizes) - 2 else output_activation
        layers.append(nn.Linear(layer_sizes[i], layer_sizes[i + 1]))
        layers.append(act())
    return nn.Sequential(*layers)


class Actor(nn.Module):
    def __init__(self, obs_dim: int, act_dim: int, hidden_sizes, action_low, action_high):
        super().__init__()
        self.net = mlp([obs_dim, *hidden_sizes, act_dim], activation=nn.ReLU, output_activation=nn.Identity)

        action_low_t = torch.as_tensor(action_low, dtype=torch.float32)
        action_high_t = torch.as_tensor(action_high, dtype=torch.float32)

        self.register_buffer('action_scale', (action_high_t - action_low_t) / 2.0)
        self.register_buffer('action_bias', (action_high_t + action_low_t) / 2.0)

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        a = torch.tanh(self.net(obs))
        return a * self.action_scale + self.action_bias


class QNetwork(nn.Module):
    def __init__(self, obs_dim: int, act_dim: int, hidden_sizes):
        super().__init__()
        self.net = mlp([obs_dim + act_dim, *hidden_sizes, 1], activation=nn.ReLU, output_activation=nn.Identity)

    def forward(self, obs: torch.Tensor, act: torch.Tensor) -> torch.Tensor:
        x = torch.cat([obs, act], dim=-1)
        return self.net(x)


class TwinCritic(nn.Module):
    def __init__(self, obs_dim: int, act_dim: int, hidden_sizes):
        super().__init__()
        self.q1 = QNetwork(obs_dim, act_dim, hidden_sizes)
        self.q2 = QNetwork(obs_dim, act_dim, hidden_sizes)

    def forward(self, obs: torch.Tensor, act: torch.Tensor):
        return self.q1(obs, act), self.q2(obs, act)

    def q1_only(self, obs: torch.Tensor, act: torch.Tensor) -> torch.Tensor:
        return self.q1(obs, act)
class TD3Agent:
    def __init__(
        self,
        obs_dim: int,
        act_dim: int,
        hidden_sizes,
        action_low,
        action_high,
        device: torch.device,
        gamma: float = 0.99,
        tau: float = 0.005,
        actor_lr: float = 1e-3,
        critic_lr: float = 1e-3,
        policy_delay: int = 2,
        target_policy_noise: float = 0.2,
        target_noise_clip: float = 0.5,
    ):
        self.device = device

        self.actor = Actor(obs_dim, act_dim, hidden_sizes, action_low, action_high).to(device)
        self.actor_target = copy.deepcopy(self.actor).to(device)

        self.critic = TwinCritic(obs_dim, act_dim, hidden_sizes).to(device)
        self.critic_target = copy.deepcopy(self.critic).to(device)

        self.actor_optim = torch.optim.Adam(self.actor.parameters(), lr=actor_lr)
        self.critic_optim = torch.optim.Adam(self.critic.parameters(), lr=critic_lr)

        self.gamma = float(gamma)
        self.tau = float(tau)

        self.policy_delay = int(policy_delay)
        self.target_policy_noise = float(target_policy_noise)
        self.target_noise_clip = float(target_noise_clip)

        self.action_low_t = torch.as_tensor(action_low, dtype=torch.float32, device=device)
        self.action_high_t = torch.as_tensor(action_high, dtype=torch.float32, device=device)

        self.total_it = 0

        # Targets start identical to online nets
        self.actor_target.load_state_dict(self.actor.state_dict())
        self.critic_target.load_state_dict(self.critic.state_dict())

    @torch.no_grad()
    def select_action(self, obs, noise_scale: float = 0.0):
        obs_t = torch.as_tensor(np.asarray(obs, dtype=np.float32).reshape(1, -1), device=self.device)
        action = self.actor(obs_t).cpu().numpy().reshape(-1)
        if noise_scale > 0:
            action = action + np.random.normal(0.0, noise_scale, size=action.shape).astype(np.float32)
        action = np.clip(action, self.action_low_t.cpu().numpy(), self.action_high_t.cpu().numpy())
        return action

    def _soft_update_(self, source: nn.Module, target: nn.Module) -> None:
        with torch.no_grad():
            for p, p_targ in zip(source.parameters(), target.parameters()):
                p_targ.data.mul_(1.0 - self.tau)
                p_targ.data.add_(self.tau * p.data)

    def train_step(self, replay_buffer: ReplayBuffer, batch_size: int):
        self.total_it += 1

        obs, act, rew, next_obs, done = replay_buffer.sample(batch_size)

        # --- Critic update (every step) ---
        with torch.no_grad():
            noise = torch.randn_like(act) * self.target_policy_noise
            noise = noise.clamp(-self.target_noise_clip, self.target_noise_clip)

            next_action = self.actor_target(next_obs) + noise
            next_action = torch.max(torch.min(next_action, self.action_high_t), self.action_low_t)

            target_q1, target_q2 = self.critic_target(next_obs, next_action)
            target_q = torch.min(target_q1, target_q2)

            y = rew + (1.0 - done) * self.gamma * target_q

        current_q1, current_q2 = self.critic(obs, act)
        critic_loss = F.mse_loss(current_q1, y) + F.mse_loss(current_q2, y)

        self.critic_optim.zero_grad()
        critic_loss.backward()
        self.critic_optim.step()

        info = {'critic_loss': float(critic_loss.item())}

        # --- Delayed actor + target updates ---
        if self.total_it % self.policy_delay == 0:
            actor_loss = -self.critic.q1_only(obs, self.actor(obs)).mean()
            self.actor_optim.zero_grad()
            actor_loss.backward()
            self.actor_optim.step()

            info['actor_loss'] = float(actor_loss.item())

            self._soft_update_(self.critic, self.critic_target)
            self._soft_update_(self.actor, self.actor_target)

        return info
replay = ReplayBuffer(obs_dim=obs_dim, act_dim=act_dim, size=BUFFER_SIZE, seed=SEED, device=device)
agent = TD3Agent(
    obs_dim=obs_dim,
    act_dim=act_dim,
    hidden_sizes=HIDDEN_SIZES,
    action_low=action_low,
    action_high=action_high,
    device=device,
    gamma=GAMMA,
    tau=TAU,
    actor_lr=ACTOR_LR,
    critic_lr=CRITIC_LR,
    policy_delay=POLICY_DELAY,
    target_policy_noise=TARGET_POLICY_NOISE,
    target_noise_clip=TARGET_NOISE_CLIP,
)

episode_returns = []
episode_lengths = []

critic_losses = []
actor_losses = []

obs, _ = env_reset(env, seed=SEED)
ep_return = 0.0
ep_len = 0

t0 = time.time()

for t in range(TOTAL_TIMESTEPS):
    if t < START_STEPS:
        action = env.action_space.sample()
    else:
        action = agent.select_action(obs, noise_scale=EXPLORATION_NOISE)

    next_obs, reward, done, terminal, _info = env_step(env, action)

    replay.add(obs, action, reward, next_obs, terminal)

    obs = next_obs
    ep_return += reward
    ep_len += 1

    if t >= UPDATE_AFTER:
        train_info = agent.train_step(replay, batch_size=BATCH_SIZE)
        critic_losses.append(train_info['critic_loss'])
        if 'actor_loss' in train_info:
            actor_losses.append(train_info['actor_loss'])

    if done:
        episode_returns.append(ep_return)
        episode_lengths.append(ep_len)

        if len(episode_returns) % 5 == 0 or not FAST_RUN:
            elapsed = time.time() - t0
            print(
                f"Episode {len(episode_returns):4d} | return {ep_return:9.1f} | len {ep_len:3d} | "
                f"t {t + 1:6d}/{TOTAL_TIMESTEPS} | elapsed {elapsed:6.1f}s"
            )

        obs, _ = env_reset(env)
        ep_return = 0.0
        ep_len = 0

env.close()

print('episodes:', len(episode_returns))
print('last return:', episode_returns[-1] if episode_returns else None)
Episode    5 | return   -1301.5 | len 200 | t   1000/10000 | elapsed    0.0s
Episode   10 | return   -1504.4 | len 200 | t   2000/10000 | elapsed   11.3s
Episode   15 | return   -1295.4 | len 200 | t   3000/10000 | elapsed   24.7s
Episode   20 | return   -1086.4 | len 200 | t   4000/10000 | elapsed   38.1s
Episode   25 | return    -817.4 | len 200 | t   5000/10000 | elapsed   51.8s
Episode   30 | return   -1039.3 | len 200 | t   6000/10000 | elapsed   65.4s
Episode   35 | return    -526.4 | len 200 | t   7000/10000 | elapsed   78.9s
Episode   40 | return    -130.1 | len 200 | t   8000/10000 | elapsed   92.5s
Episode   45 | return    -244.9 | len 200 | t   9000/10000 | elapsed  106.1s
Episode   50 | return    -247.3 | len 200 | t  10000/10000 | elapsed  119.8s
episodes: 50
last return: -247.34952581981875
# Plot episodic returns
df = pd.DataFrame(
    {
        'episode': np.arange(1, len(episode_returns) + 1),
        'return': episode_returns,
        'length': episode_lengths,
    }
)

window = min(10, max(1, len(df)))
df['return_ma'] = df['return'].rolling(window=window, min_periods=1).mean()

fig = go.Figure()
fig.add_trace(go.Scatter(x=df['episode'], y=df['return'], mode='lines+markers', name='Return'))
fig.add_trace(go.Scatter(x=df['episode'], y=df['return_ma'], mode='lines', name=f'{window}-episode MA'))
fig.update_layout(
    title=f'TD3 training on {ENV_ID}: episodic return',
    xaxis_title='Episode',
    yaxis_title='Return',
)
fig.show()

Notes, diagnostics, and common pitfalls#

  • Terminal masking: on time-limit truncation you typically still bootstrap, so we mask the target only on terminated (Gymnasium), never on truncated.

  • Twin critics: the minimum is used only in the bootstrap target \(y\); the actor gradient uses \(Q_{\theta_1}\) alone.

  • Delayed updates: do not update the actor every step; it should be updated every POLICY_DELAY critic updates.

  • Target policy smoothing: the noise added to target actions is separate from exploration noise.

  • Exploration: TD3 is deterministic; you must add noise to actions during data collection.
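The terminal-masking point above in numbers, with illustrative reward, discount, and next-state value:

```python
gamma = 0.99
rew = 1.0
next_q = 50.0  # min over target critics at the smoothed next action (illustrative)

# Episode ended because the MDP terminated: no future value to bootstrap
y_terminated = rew + gamma * (1.0 - 1.0) * next_q  # -> 1.0

# Episode ended only because of the time limit: still bootstrap
y_truncated = rew + gamma * (1.0 - 0.0) * next_q   # -> 50.5

print(y_terminated, y_truncated)
```

Storing `done` (terminated OR truncated) in the buffer instead of `terminal` would wrongly zero out the second case and bias the critic pessimistic near the time limit.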

Stable-Baselines3 TD3 (reference implementation)#

Stable-Baselines3 (SB3) includes a PyTorch TD3 implementation: https://stable-baselines3.readthedocs.io/en/master/modules/td3.html

This is useful as a reference and a quick way to validate your intuition against a well-tested baseline.

If you want to run it locally:

pip install stable-baselines3

If you have SB3 installed, a minimal training script looks like:

import gymnasium as gym
import numpy as np
from stable_baselines3 import TD3
from stable_baselines3.common.noise import NormalActionNoise

env = gym.make('Pendulum-v1')
n_actions = env.action_space.shape[-1]

action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))

model = TD3(
    policy='MlpPolicy',
    env=env,
    action_noise=action_noise,
    verbose=1,
)
model.learn(total_timesteps=100_000)

At the end of this notebook we summarize SB3’s TD3 hyperparameters.

Stable-Baselines3 TD3 hyperparameters (glossary + defaults)#

Web research source: https://stable-baselines3.readthedocs.io/en/master/modules/td3.html

Constructor signature (defaults):

TD3(policy, env, learning_rate=0.001, buffer_size=1000000, learning_starts=100, batch_size=256, tau=0.005, gamma=0.99, train_freq=1, gradient_steps=1, action_noise=None, replay_buffer_class=None, replay_buffer_kwargs=None, optimize_memory_usage=False, n_steps=1, policy_delay=2, target_policy_noise=0.2, target_noise_clip=0.5, stats_window_size=100, tensorboard_log=None, policy_kwargs=None, verbose=0, seed=None, device='auto', _init_setup_model=True)

Glossary:

  • policy: policy class/name (e.g., MlpPolicy, CnnPolicy).

  • env: environment instance or env ID string.

  • learning_rate (default 1e-3): Adam learning rate (SB3 uses the same LR for actor and critics).

  • buffer_size (default 1_000_000): replay buffer capacity.

  • learning_starts (default 100): number of environment steps collected before training begins.

  • batch_size (default 256): mini-batch size sampled from replay.

  • tau (default 0.005): Polyak coefficient \(\tau\) for target network updates.

  • gamma (default 0.99): discount factor \(\gamma\).

  • train_freq (default 1): how often to train (steps), or a tuple like (n, 'step') / (n, 'episode').

  • gradient_steps (default 1): gradient updates per training iteration.

  • action_noise (default None): exploration noise used when collecting data (e.g. Gaussian or OU noise).

  • policy_delay (default 2): actor/target update period \(d\) (critics update every step).

  • target_policy_noise (default 0.2): \(\sigma\) in target policy smoothing.

  • target_noise_clip (default 0.5): \(c\) in target policy smoothing (clip range).

  • replay_buffer_class (default None): custom replay buffer class.

  • replay_buffer_kwargs (default None): kwargs passed to the replay buffer.

  • optimize_memory_usage (default False): memory-efficient replay buffer variant.

  • n_steps (default 1): n-step returns (when >1 uses an n-step replay buffer).

  • stats_window_size (default 100): logging window size (episodes averaged).

  • tensorboard_log (default None): TensorBoard log directory.

  • policy_kwargs (default None): policy/network architecture options.

  • verbose (default 0): verbosity (0/1/2).

  • seed (default None): RNG seed.

  • device (default 'auto'): device selection (CPU/GPU).

  • _init_setup_model (default True): whether to build networks at init.

References#

  • Fujimoto, van Hoof, Meger (2018): Addressing Function Approximation Error in Actor-Critic Methods (TD3)

  • Stable-Baselines3 docs / source code (TD3)